import base64
import unittest

from ..bitcoin import deserialize_privkey
from ..ecc import (
    CURVE_ORDER,
    GENERATOR,
    POINT_AT_INFINITY,
    PRIVATE_KEY_BYTECOUNT,
    ECPrivkey,
    ECPubkey,
    InvalidECPointException,
    SignatureType,
    verify_message_with_address,
)
from ..util import randrange


class TestECC(unittest.TestCase):
    """Tests for the elliptic-curve primitives exposed by ``..ecc``.

    Covers ECIES encryption round-trips, message signing/verification (with
    the default signature type and with ``SignatureType.BITCOIN``), and basic
    group-law sanity checks on curve points.
    """

    def test_crypto(self):
        for message in [
            b"Chancellor on brink of second bailout for banks",
            b"\xff" * 512,
        ]:
            # subTest reports which message failed; the length is enough to
            # identify it without dumping 512 bytes into the test output.
            with self.subTest(message_len=len(message)):
                self._do_test_crypto(message)

    def _do_test_crypto(self, message):
        """Round-trip ``message`` through encryption and signing.

        A fresh random private key is drawn on every call. The message is
        encrypted to both the compressed and the uncompressed serialization
        of the matching public key, decrypted with the private key, then
        signed and verified.
        """
        pvk = randrange(GENERATOR.order())

        Pub = pvk * GENERATOR
        pubkey_c = Pub.get_public_key_bytes(compressed=True)
        pubkey_u = Pub.get_public_key_bytes(compressed=False)

        eck = ECPrivkey(
            int.to_bytes(
                pvk, length=PRIVATE_KEY_BYTECOUNT, byteorder="big", signed=False
            )
        )
        pubkey = ECPubkey(pubkey_c)

        # Encrypt to the compressed public key; the private key must decrypt it.
        enc = pubkey.encrypt_message(message)
        dec = eck.decrypt_message(enc)
        self.assertEqual(message, dec)

        # The same point given in uncompressed form must also work as an
        # encryption target. (Previously this step re-decrypted `enc`, which
        # added no coverage.)
        enc2 = ECPubkey(pubkey_u).encrypt_message(message)
        dec2 = eck.decrypt_message(enc2)
        self.assertEqual(message, dec2)

        # A signature made with the private key must verify against the
        # matching public key.
        signature = eck.sign_message(message, True)
        self.assertTrue(pubkey.verify_message(signature, message))

    def test_msg_signing(self):
        """Sign with fixed WIF keys and compare against known-good signatures."""
        msg1 = b"Chancellor on brink of second bailout for banks"
        msg2 = b"Electrum"

        def sign_message_with_wif_privkey(wif_privkey, msg):
            # Only the raw key and its compression flag are needed to sign;
            # the decoded script type is unused.
            _txin_type, privkey, compressed = deserialize_privkey(wif_privkey)
            key = ECPrivkey(privkey)
            return key.sign_message(msg, compressed)

        sig1 = sign_message_with_wif_privkey(
            "L1TnU2zbNaAqMoVh65Cyvmcjzbrj41Gs9iTLcWbpJCMynXuap6UN", msg1
        )
        addr1 = "15hETetDmcXm1mM4sEf7U2KXC9hDHFMSzz"
        sig2 = sign_message_with_wif_privkey(
            "5Hxn5C4SQuiV6e62A1MtZmbSeQyrLFhu5uYks62pU5VBUygK2KD", msg2
        )
        addr2 = "1GPHVTY8UD9my6jyP4tb2TYJwUbDetyNC6"

        sig1_b64 = base64.b64encode(sig1)
        sig2_b64 = base64.b64encode(sig2)

        # NOTE: you cannot rely on exact binary patterns of signatures
        # produced by libsecp versus python ecdsa, etc. This is because nonces
        # may differ.  We ran into this when switching from Bitcoin Core libsecp
        # to Bitcoin ABC libsecp.  Amaury Sechet confirmed this to be true.
        # Mark Lundeberg suggested we still do binary exact matches from a set,
        # though, just to notice when nonces of the underlying lib change.
        # So the test below has been updated to use a set.
        # self.assertEqual(sig1_b64, b'H/9jMOnj4MFbH3d7t4yCQ9i7DgZU/VZ278w3+ySv2F4yIsdqjsc5ng3kmN8OZAThgyfCZOQxZCWza9V5XzlVY0Y=')
        # self.assertEqual(sig2_b64, b'G84dmJ8TKIDKMT9qBRhpX2sNmR0y5t+POcYnFFJCs66lJmAs3T8A6Sbpx7KA6yTQ9djQMabwQXRrDomOkIKGn18=')
        #
        # Hardcoded sigs were generated with an old version of Python ECDSA by Electrum team.
        # New sigs generated with Bitcoin-ABC libsecp variant.  Both verify against each other ok.
        # They just have different byte patterns we can expect due to different nonces used (both are to spec).
        accept_b64_1 = {
            # Generated by Electrum ABC 5.0.2 using python-ecdsa
            b"IANkoyMsQ9aqJcqAfuPUb1iL1DVzlEzqZKQ+V2DrL6dteQvvsqTWIk5APuVthkXDcOhvZ3fvhaW90ADgxzUyd0Y=",
            # Generated by Electrum ABC 5.0.2 using libsecp256k1
            b"H24keoFIst9/UaMfvFejNk52pRMK3xGaC784Dz4NC/sOLCSy3y5Jpf7+Pk5BDQuKnOP6+Fr68yD0acMLt3WQXQ8=",
            # Generated by Bitcoin ABC 0.23.2
            b"ICbt3+l7L3C9ANJp3I2UK9h3i1AyTJeqFadIAhKQUm8gJRw2nKV+eCcaMgzm5couc12Yba6U/6YTmmNWVzaKv+A=",
        }
        accept_b64_2 = {
            # Generated by Electrum ABC 5.0.2 using python-ecdsa
            b"G0dN2iKid9zT79uz4RLfw2nDSB1AE2JsmYtzUxM4YXhvQ2iZxFvs9teeExaopgGxwPvadRPmP4oEXTZt4P3Vwic=",
            # Generated by Electrum ABC 5.0.2 using libsecp256k1
            b"HBF0Y4u7KECCNB8rubGysV0ZFiYMevLhdBrYck7MFgZuNSv+DODii6YQ2HyKyHYsZ7Q6ZRjkaaMXDacdCAoQ63k=",
            # Generated by Bitcoin ABC 0.23.2
            b"G9jE+8cxPJV9HKeqDh8xgIE+isk/8/3Jf7GNAlfEsEgGa7mdegoAQmBGTtblfQ6v+ciz+xUEubUh9HY5lm+rtZ8=",
        }

        # Does each signature match one of our hard-coded sigs in the set?
        # assertIn shows the offending value on failure, unlike assertTrue.
        self.assertIn(sig1_b64, accept_b64_1)
        self.assertIn(sig2_b64, accept_b64_2)
        # Can it verify its own sigs?
        self.assertTrue(verify_message_with_address(addr1, sig1, msg1))
        self.assertTrue(verify_message_with_address(addr2, sig2, msg2))
        # Can we verify the hardcoded sigs? (This checks that the underlying
        # ECC libs accept signatures made with any of the nonces.)
        for addr, msg, accepted in (
            (addr1, msg1, accept_b64_1),
            (addr2, msg2, accept_b64_2),
        ):
            for sig_b64 in accepted:
                with self.subTest(address=addr, sig=sig_b64):
                    self.assertTrue(
                        verify_message_with_address(
                            addr, base64.b64decode(sig_b64), msg
                        )
                    )

        # A malformed signature must not verify.
        self.assertFalse(verify_message_with_address(addr1, b"wrong", msg1))
        # A valid signature for a different key/message must not verify.
        self.assertFalse(verify_message_with_address(addr1, sig2, msg1))
        self.assertFalse(verify_message_with_address(addr2, sig1, msg2))

    def test_legacy_msg_signing(self):
        """Test that we can use the legacy "Bitcoin Signed Message:\n" message magic."""
        msg = b"Chancellor on brink of second bailout for banks"
        addr = "15hETetDmcXm1mM4sEf7U2KXC9hDHFMSzz"

        _txin_type, privkey, compressed = deserialize_privkey(
            "L1TnU2zbNaAqMoVh65Cyvmcjzbrj41Gs9iTLcWbpJCMynXuap6UN"
        )
        key = ECPrivkey(privkey)
        sig = key.sign_message(msg, compressed, sigtype=SignatureType.BITCOIN)

        accepted_signatures = {
            # Older core libsecp/python ecdsa nonce produces this deterministic signature
            b"H/9jMOnj4MFbH3d7t4yCQ9i7DgZU/VZ278w3+ySv2F4yIsdqjsc5ng3kmN8OZAThgyfCZOQxZCWza9V5XzlVY0Y=",
            # New Bitcoin ABC libsecp nonce produces this deterministic signature
            b"IA+oq/uGz4kKA2bNgxPcM+T216abyUiBhofMg1J8fC5BLAbbIpF2toCHaO7/LQAxhQBtu5D6ROq1JjXiRwPAASg=",
        }
        self.assertIn(base64.b64encode(sig), accepted_signatures)

        for sig_ in accepted_signatures:
            with self.subTest(sig=sig_):
                self.assertTrue(
                    verify_message_with_address(
                        address=addr,
                        sig65=base64.b64decode(sig_),
                        message=msg,
                        sigtype=SignatureType.BITCOIN,
                    )
                )

    def test_point_infinity(self):
        """Check which point expressions do and do not yield the point at infinity."""
        G = ECPubkey(
            bytes.fromhex(
                "0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798"
            )
        )
        self.assertEqual(G, GENERATOR)
        # Same x-coordinate as G with the opposite-parity prefix (03), i.e. -G.
        minusG = ECPubkey(
            bytes.fromhex(
                "0379be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798"
            )
        )
        # -G again, this time in the 04-prefixed uncompressed encoding.
        uncompressed_minusG = ECPubkey(
            bytes.fromhex(
                "0479be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798b7c52588d95c3b9aa25b0403f1eef75702e84bb7597aabe663b82f6f04ef2777"
            )
        )

        for case, inf in enumerate(
            (
                G + minusG,
                G + uncompressed_minusG,
                CURVE_ORDER * G,
                CURVE_ORDER * (5 * G),
                (CURVE_ORDER * G) * 5,
            )
        ):
            with self.subTest(infinity_case=case):
                self.assertTrue(inf.is_at_infinity())
                self.assertEqual(inf, POINT_AT_INFINITY)

        for case, not_inf in enumerate(
            (
                G,
                (CURVE_ORDER - 2) * G,
                (CURVE_ORDER - 1) * G,
                (CURVE_ORDER + 1) * G,
            )
        ):
            with self.subTest(finite_case=case):
                self.assertFalse(not_inf.is_at_infinity())
                self.assertNotEqual(not_inf, POINT_AT_INFINITY)

    def test_point_not_on_curve(self):
        """Constructing a pubkey from an x-coordinate with no curve point must fail."""
        with self.assertRaises(InvalidECPointException):
            ECPubkey(
                bytes.fromhex(
                    "030000000000000000000000000000000000000000000000000000000000000007"
                )
            )

    def test_ecc_sanity(self):
        """Basic group-law identities for scalar multiplication and addition."""
        G = GENERATOR
        self.assertEqual(G.order(), CURVE_ORDER)
        self.assertEqual(11 * G, 7 * G + 4 * G)
        self.assertEqual((CURVE_ORDER + 2) * G, 2 * G)
        self.assertEqual((CURVE_ORDER - 2) * G, -2 * G)
        self.assertNotEqual((CURVE_ORDER - 2) * G, (CURVE_ORDER - 1) * G)
        self.assertEqual(2 * G, POINT_AT_INFINITY + 2 * G)
        self.assertEqual(POINT_AT_INFINITY, 3 * G + (-3 * G))


if __name__ == "__main__":
    # Allow running this module directly as a script using unittest's CLI runner.
    unittest.main()
